bug fix: RuntimeError when training GRPO with LoRA and PtEngine #5645
Summary of Changes
Hello @chenjianhuii, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a critical bug that prevented successful training of models using GRPO with LoRA when PtEngine was involved. The core issue stemmed from an incompatibility between PyTorch's inference mode and the dynamic re-enabling of PEFT layers. The fix ensures smooth operation of the training pipeline by adjusting the gradient context during inference.
Highlights
- Fix for RuntimeError in GRPO training with LoRA: Resolves a RuntimeError that occurred when training GRPO with LoRA and PtEngine, specifically due to a conflict between torch.inference_mode() and the re-enabling of PEFT layers.
- Change from torch.inference_mode() to torch.no_grad(): The infer method in pt_engine.py now uses @torch.no_grad() instead of @torch.inference_mode() to prevent the RuntimeError when requires_grad(True) is called on tensors previously used in inference mode.
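The difference between the two decorators can be seen in a minimal sketch. `TinyEngine` and its two methods are hypothetical stand-ins invented for illustration; they are not the actual `PtEngine` code. The sketch only assumes the standard PyTorch behavior that tensors produced under `torch.inference_mode()` are flagged as inference tensors, while tensors produced under `torch.no_grad()` are ordinary tensors without gradient tracking.

```python
import torch


class TinyEngine:
    """Hypothetical stand-in for an inference engine (not ms-swift's PtEngine)."""

    def __init__(self):
        self.model = torch.nn.Linear(4, 2)

    @torch.inference_mode()
    def infer_inference_mode(self, x):
        # Tensors created here are permanently marked as inference tensors.
        return self.model(x)

    @torch.no_grad()
    def infer_no_grad(self, x):
        # Tensors created here are normal tensors with requires_grad=False.
        return self.model(x)


engine = TinyEngine()
x = torch.randn(3, 4)

out_inf = engine.infer_inference_mode(x)
out_ng = engine.infer_no_grad(x)

print(out_inf.is_inference(), out_inf.requires_grad)
print(out_ng.is_inference(), out_ng.requires_grad)
```

Neither output tracks gradients, but only the `no_grad` output can later be opted back into autograd, which is the property the fix relies on.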
Code Review
This pull request correctly resolves a RuntimeError that occurs when using a model for both inference with PtEngine and subsequent PEFT training. The change from @torch.inference_mode() to @torch.no_grad() is the appropriate solution, as it prevents model parameters from being immutably tagged as 'inference tensors', thus allowing requires_grad to be enabled later during training. While there is a potential minor performance trade-off, it is a necessary compromise to ensure the model's flexibility for mixed-use scenarios. I have added a comment to suggest documenting this important context directly in the code to aid future maintenance.
    else:
        return await queue.get()

    # Ensure `template._post_encode` has no gradient.
To improve maintainability and prevent future regressions, it would be beneficial to add a comment explaining why @torch.no_grad() is used here instead of the potentially more performant @torch.inference_mode(). This will clarify the reasoning for other developers who might not have the context of the GRPOTrainer RuntimeError.
# Use `torch.no_grad()` instead of `torch.inference_mode()` to allow model
# parameters to have `requires_grad=True` set on them later (e.g., during
# PEFT training). This prevents a RuntimeError when the model is used for
# both inference and training.
I can't reproduce this with the example script https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/pt.sh. Could you open an issue that includes a reproduction script along with your environment details?
I opted for a more subtle modification, switching from the no_grad context to inference_mode for get_logps. Can you retry?

PR type
PR information
During PEFT training, GRPOTrainer uses null_ref_context() to temporarily disable PEFT in order to obtain the logps of the reference model. Upon exiting this context, re-enabling PEFT requires calling requires_grad_(True) on each layer. However, the model was previously used in PtEngine under @torch.inference_mode(), which results in the error "RuntimeError: Setting requires_grad=True on inference tensor outside InferenceMode is not allowed." Switching from @torch.inference_mode() to @torch.no_grad() resolves this issue but may cost some performance. I'm uncertain whether there is a better solution.
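The PyTorch rule behind this error can be illustrated in isolation. The sketch below does not reproduce the GRPOTrainer code path or explain how tensors in the real training run end up as inference tensors; it only shows, with hypothetical variable names, that a tensor created under `torch.inference_mode()` rejects `requires_grad_(True)` afterwards, while one created under `torch.no_grad()` accepts it.

```python
import torch

# A tensor created inside inference mode is an inference tensor for life.
with torch.inference_mode():
    weight_inf = torch.zeros(2, 2)

# A tensor created under no_grad is a normal tensor with requires_grad=False.
with torch.no_grad():
    weight_ng = torch.zeros(2, 2)

# Re-enabling gradients on the inference tensor raises a RuntimeError.
try:
    weight_inf.requires_grad_(True)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

# The same call succeeds on the tensor created under no_grad.
weight_ng.requires_grad_(True)

print(error_message)
print(weight_ng.requires_grad)
```

This is why replacing the decorator avoids the failure: nothing produced under `torch.no_grad()` is locked out of autograd later.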